{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 添加 relay 算子\n",
    "\n",
    "参考：[relay_add_op](https://xinetzone.github.io/tvm/docs/dev/how_to/relay_add_op.html)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "属性是固定的参数，应该在编译时就知道。比如卷积算子的 `stride` 和 `expand` 是属于卷积算子属性节点的字段的适当的例子。属性一般定义在 `include/tvm/relay/attrs/` 文件夹中，也可以直接通过 Python 接口定义："
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`register_op_attr` 函数用于注册算子的属性。\n",
    "\n",
    "参数：\n",
    "- `op_name`（str 类型）：算子的名称。\n",
    "- `attr_key`（str 类型）：属性的名称。\n",
    "- `value`（object 类型，可选）：要设置的值。\n",
    "- `level`（int 类型，可选）：优先级级别，默认为 `10`。\n",
    "\n",
    "返回值：\n",
    "- 如果 `value` 没有指定，则返回注册函数 `fregister`。\n",
    "- 如果 `value` 指定了，则直接返回注册函数 `_register(value)`。\n",
    "\n",
    "内部函数 `_register` 用于实际执行算子注册，它调用了 `_ffi_api.RegisterOpAttr` 函数来注册算子的属性。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import set_env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm import relay\n",
    "from tvm.ir import register_op_attr\n",
    "from tvm.relay.op import op as _op, op_attrs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "op_name = \"add_one\"\n",
    "_op.register(op_name, \"doc(Add one to the given tensor.)doc\")\n",
    "# op_attr = register_op_attr(op_name, \"testing\", op_attrs.ScanopAttrs)\n",
    "_op.get(op_name).set_num_inputs(1)\n",
    "_op.get(op_name).add_argument(\"data\", \"Tensor\", \"The input tensor.\")\n",
    "# 添加关系函数\n",
    "_op.get(op_name).add_type_rel(\"Identity\") # \"Identity\", \"Broadcast\"\n",
    "_op.register_pattern(op_name, _op.OpPattern.ELEMWISE)\n",
    "_op.register_stateful(op_name, False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_op.get(op_name).get_attr(\"TOpPattern\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "free_var %x: Tensor[(10), float32];\n",
      "add_one(%x)\n"
     ]
    }
   ],
   "source": [
    "x = relay.var(\"x\", dtype=\"float32\", shape=(10,))\n",
    "\n",
    "custom_op = _op.get(\"add_one\")\n",
    "z = relay.Call(custom_op, [x])\n",
    "print(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tvmz",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
